from tool.logger import Logger
from data.dl_getter import get_dl_tr, get_dl_vl, cycle
from model.model_getter import get_model
from train.opt import get_opt
from model.model_io import _resume, _reload, _save_model


class MLBase:
    """Bundle of training state: config, data loaders, model, optimizer, replay buffer.

    An instance is either built fresh from a run configuration (``args``) or
    created from another instance (``other``), sharing that instance's state.

    NOTE: ``save_model`` reads ``self.epoch``, which this class never sets;
    a subclass or the training loop must assign it before checkpointing.
    """

    def __init__(self, args=None, other=None):
        """Build fresh state from ``args`` or share the state of ``other``.

        Args:
            args: run configuration. ``args.method`` is read here. When it is
                ``'finetune'``, only ``self.model.head`` is handed to the
                optimizer factory; otherwise the whole model is. ``args`` is
                also passed to the loader/model/optimizer factories. It is
                ignored when ``other`` is given.
            other: an existing instance. Its attributes are assigned by
                reference, not copied, so both instances use the same model,
                optimizer, loaders and replay buffer objects.

        Raises:
            ValueError: if neither ``args`` nor ``other`` is provided.
        """
        if other is None:
            if args is None:
                # Fail early with a clear message instead of crashing inside a
                # factory call or on ``args.method`` with an opaque error.
                raise ValueError("MLBase requires either `args` or `other`")
            self.args = args
            self.tr_dl = get_dl_tr(args)
            self.vl_dl = get_dl_vl(args)
            self.uc_dl = None
            self.model = get_model(args)
            opt_target = self.model.head if args.method == 'finetune' else self.model
            self.optimizer = get_opt(args, opt_target)
            self.replay_buffer = None
        else:
            self.args = other.args
            self.tr_dl = other.tr_dl
            self.vl_dl = other.vl_dl
            self.uc_dl = other.uc_dl
            self.model = other.model
            self.optimizer = other.optimizer
            self.replay_buffer = other.replay_buffer

    def save_model(self, loss, best_acc):
        """Write a checkpoint periodically, and always on the final epoch.

        A checkpoint is written via ``_save_model`` with ``epoch=self.epoch + 1``
        when either condition holds:

        * periodic: ``loss < 1e3`` and ``self.epoch + 1`` is a multiple of
          ``args.save_freq``. A falsy ``save_freq`` (0 or None) disables
          periodic saves instead of raising ZeroDivisionError/TypeError.
        * final: ``self.epoch == args.epochs - 1``. This save ignores ``loss``.

        Args:
            loss: current loss. Periodic saves are skipped unless it is below
                1e3; NaN compares False and is skipped too.
                NOTE(review): the threshold presumably guards against
                checkpointing a diverged run -- confirm intent.
            best_acc: forwarded unchanged to ``_save_model``.
        """
        next_epoch = self.epoch + 1
        is_last_epoch = self.epoch == self.args.epochs - 1
        save_freq = self.args.save_freq
        is_periodic = (loss < 1e3 and bool(save_freq)
                       and next_epoch % save_freq == 0)
        if is_periodic or is_last_epoch:
            _save_model(self.args, best_acc, self.model, self.optimizer,
                        self.replay_buffer, epoch=next_epoch)

    def resume(self):
        """Restore state via ``_resume`` and return whatever it returns."""
        return _resume(self.args, self.model, self.optimizer, self.replay_buffer)

    def reload(self, loss):
        """Reload state via ``_reload``, forwarding ``loss``.

        Returns ``_reload``'s result, mirroring ``resume``. Previously it was
        discarded; returning it is backward-compatible for existing callers.
        """
        return _reload(self.args, loss, self.model, self.optimizer,
                       self.replay_buffer)


